{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Za8-Nr5k11fh"
      },
      "source": [
        "##### Copyright 2018 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "Eq10uEbw0E4l"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UysiGN3tGQHY"
      },
      "source": [
        "# Running TFLite models"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2hOrvdmswy5O"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/examples/blob/master/courses/udacity_intro_to_tensorflow_lite/tflite_c01_linear_regression.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />\n",
        "    Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/examples/blob/master/courses/udacity_intro_to_tensorflow_lite/tflite_c01_linear_regression.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />\n",
        "    View source on GitHub</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W-VhTkyTGcaQ"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dy4BcTjBFTWx"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "import pathlib\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "from tensorflow.keras.models import Model\n",
        "from tensorflow.keras.layers import Input"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ceibQLDeGhI4"
      },
      "source": [
        "## Create a basic model of the form y = mx + c"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YIBCsjQNF46Z"
      },
      "outputs": [],
      "source": [
        "# Training data sampled from the line y = 2x - 1.\n",
        "# NumPy arrays are used because recent Keras releases reject plain Python\n",
        "# lists of scalars as `fit` inputs.\n",
        "x = np.array([-1, 0, 1, 2, 3, 4], dtype=np.float32)\n",
        "y = np.array([-3, -1, 1, 3, 5, 7], dtype=np.float32)\n",
        "\n",
        "# A single Dense unit learns one weight (m) and one bias (c).\n",
        "model = tf.keras.models.Sequential([\n",
        "    tf.keras.layers.Dense(units=1, input_shape=[1])\n",
        "])\n",
        "model.compile(optimizer='sgd', loss='mean_squared_error')\n",
        "model.fit(x, y, epochs=200, verbose=1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EjsB-QICGt6L"
      },
      "source": [
        "## Generate a SavedModel"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a9xcbK7QHOfm"
      },
      "outputs": [],
      "source": [
        "# Export the trained Keras model in SavedModel format to saved_model/1.\n",
        "export_dir = 'saved_model/1'\n",
        "tf.saved_model.save(obj=model, export_dir=export_dir)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RRtsNwkiGxcO"
      },
      "source": [
        "## Convert the SavedModel to TFLite"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TtM8yKTVTpD3"
      },
      "outputs": [],
      "source": [
        "# Build a converter from the SavedModel directory, then serialize the model\n",
        "# into the TFLite flatbuffer format.\n",
        "converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir=export_dir)\n",
        "tflite_model = converter.convert()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4idYulcNHTdO"
      },
      "outputs": [],
      "source": [
        "# Save the serialized model to model.tflite; the cell output is the number\n",
        "# of bytes written.\n",
        "tflite_model_file = pathlib.Path('model.tflite')\n",
        "tflite_model_file.write_bytes(tflite_model)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HgGvp2yBG25Q"
      },
      "source": [
        "## Initialize the TFLite interpreter to try it out"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DOt94wIWF8m7"
      },
      "outputs": [],
      "source": [
        "# Build an interpreter straight from the in-memory flatbuffer and allocate\n",
        "# its tensors before any inference call.\n",
        "interpreter = tf.lite.Interpreter(model_content=tflite_model)\n",
        "interpreter.allocate_tensors()\n",
        "\n",
        "# Look up the tensor metadata needed to feed inputs and read outputs.\n",
        "output_details = interpreter.get_output_details()\n",
        "input_details = interpreter.get_input_details()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JGYkEK08F8qK"
      },
      "outputs": [],
      "source": [
        "# Run the TensorFlow Lite model on random inputs and check it against Keras.\n",
        "input_shape = input_details[0]['shape']\n",
        "inputs, outputs = [], []\n",
        "for _ in range(100):\n",
        "  input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)\n",
        "  interpreter.set_tensor(input_details[0]['index'], input_data)\n",
        "\n",
        "  interpreter.invoke()\n",
        "  tflite_results = interpreter.get_tensor(output_details[0]['index'])\n",
        "\n",
        "  # Run the original TensorFlow model on the same input.\n",
        "  tf_results = model(tf.constant(input_data))\n",
        "  output_data = np.array(tf_results)\n",
        "\n",
        "  # The converted model must reproduce the Keras prediction; fail loudly if\n",
        "  # the two disagree beyond float32 rounding.\n",
        "  np.testing.assert_almost_equal(tflite_results, output_data, decimal=5)\n",
        "\n",
        "  # Record the TFLite prediction so the results collected here come from the\n",
        "  # converted model.\n",
        "  inputs.append(input_data[0][0])\n",
        "  outputs.append(tflite_results[0][0])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "t1gQGH1KWAgW"
      },
      "source": [
        "## Visualize the model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ccvQ1mEJVrqo"
      },
      "outputs": [],
      "source": [
        "# Draw the sampled (input, output) pairs as a red line.\n",
        "fig, ax = plt.subplots()\n",
        "ax.plot(inputs, outputs, 'r')\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WbugMH6yKvtd"
      },
      "source": [
        "## Download the TFLite model file"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FOAIMETeJmkc"
      },
      "outputs": [],
      "source": [
        "# Offer the converted model as a browser download when running in Colab.\n",
        "try:\n",
        "  from google.colab import files\n",
        "  files.download(str(tflite_model_file))\n",
        "except ImportError:\n",
        "  # google.colab is only importable inside Colab; elsewhere the file is\n",
        "  # already on local disk, so just say where it is.\n",
        "  print(f'Not running in Colab; model saved at {tflite_model_file.resolve()}')\n",
        "except Exception as err:\n",
        "  # Keep the download best-effort, but report the failure instead of hiding it.\n",
        "  print(f'Could not download {tflite_model_file}: {err}')"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "tflite_c01_linear_regression.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
